
kv-cache : refactor + add llama_memory_state_i #13746


Merged
merged 14 commits into master from gg/kv-cache-simplify-part3 on May 31, 2025

Conversation

ggerganov
Member

@ggerganov ggerganov commented May 24, 2025

cont #13706 (comment), #13194

Main goal here is to simplify the abstract interface of struct llama_kv_cache.

Overview

Changes to the internal struct llama_kv_cache abstract interface:

  • Remove llama_kv_cache::commit()
  • Remove llama_kv_cache::restore()
  • Remove llama_kv_cache::sbatch_init()
  • Remove llama_kv_cache::ubatch_next()
  • Remove llama_kv_cache::find_slot()
  • Remove llama_kv_cache_guard
  • Add:
--- llama-memory.h

    // the interface for managing the memory state during batch processing
    // this interface is implemented per memory type. see:
    //   - llama_kv_cache_unified_state
    //   - llama_kv_cache_unified_iswa_state
    //   ...
    //
    // the only method that can mutate the memory and the memory state is llama_memory_state_i::apply()
    //
    // TODO: rename to llama_memory_context_i ?
    class llama_memory_state_i {
    public:
        virtual ~llama_memory_state_i() = default;
    
        // consume the current ubatch from the state and proceed to the next one
        // return false if we are done
        virtual bool next() = 0;
    
        // apply the memory state for the current ubatch to the memory object
        // return false on failure
        virtual bool apply() = 0;
    
        // TODO: this might get reworked in the future when refactoring llama_batch
        virtual std::vector<int64_t> & out_ids() = 0;
    
        // get the current ubatch
        virtual const llama_ubatch & get_ubatch() const = 0;
    
        // get the status of the memory state
        virtual llama_memory_status get_status() const = 0;
    };
    
    using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;

--- llama-kv-cache.h

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // return a state object containing the ubatches and KV cache state required to process them
    // check the llama_memory_state_i::get_status() for the result
    virtual llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) = 0;

    // simulate full cache, used for allocating worst-case compute buffers
    virtual llama_memory_state_ptr init_full() = 0;

This new interface changes the logic in llama_decode() to first make sure that the input batch can fit into the cache, and only after that do we start processing the ubatches. This check correctly takes SWA masking into account and also makes sure that the cache is not modified before we start the actual computation.

note: the latter is not yet true for the recurrent cache - see comments in the code
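
To illustrate, here is a rough sketch of how llama_decode() can drive this interface (illustrative only - compute_ubatch() is a placeholder for the graph build + compute step, and the exact status constant name is assumed):

    llama_memory_state_ptr mstate = kv_self->init_batch(batch, n_ubatch, embd_pooled, logits_all);

    if (mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        // the batch cannot fit into the cache - nothing has been modified yet
        return -1;
    }

    do {
        const llama_ubatch & ubatch = mstate->get_ubatch();

        // only now is the memory mutated, and only for the current ubatch
        if (!mstate->apply()) {
            return -1;
        }

        // placeholder: build the graph and run the computation for this ubatch
        compute_ubatch(ubatch);
    } while (mstate->next());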

Another important update in this PR is that the find_slot() logic for unified caches is now improved. Previously, we looked for a slot (i.e. a set of contiguous cells) that was completely empty in order to place the ubatch in it. We now allow the slot to contain data from the same or another sequence, as long as that data is masked (either by causality or by SWA):

// keep track of what the minimum sequence positions would be if we accept the ubatch
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
    seq_pos_min[s] = cells.seq_pos_min(s);
}

bool found = true;
for (uint32_t i = 0; i < n_tokens; i++) {
    const llama_pos    pos    = ubatch.pos[i];
    const llama_seq_id seq_id = ubatch.seq_id[i][0];

    // can we use this cell? either:
    // - the cell is empty
    // - the cell is occupied only by one sequence:
    //   - mask causally, if the sequence is the same as the one we are inserting
    //   - mask SWA, using current max pos for that sequence in the cache
    //     always insert in the cell with minimum pos
    bool can_use = cells.is_empty(head_cur + i);

    if (!can_use && cells.seq_count(head_cur + i) == 1) {
        const llama_pos pos_cell = cells.pos_get(head_cur + i);

        // causal mask
        if (cells.seq_has(head_cur + i, seq_id)) {
            can_use = pos_cell >= pos;
        }

        if (!can_use) {
            const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

            // SWA mask
            if (pos_cell == seq_pos_min[seq_id_cell] &&
                is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                seq_pos_min[seq_id_cell]++;
                can_use = true;
            }
        }
    }

    if (!can_use) {
        found = false;

        head_cur += i + 1;
        n_tested += i + 1;

        break;
    }
}

This change is needed for the next PR, which will optimize the SWA cache to use just n_swa + n_ubatch cells, and it also has some other nice properties. For example, we no longer have to explicitly prune tokens after successful batch processing, which simplifies the logic significantly and allows us to re-enable speculative decoding for SWA models (this will also be done in the next PR).

The worst-case graph reserve logic is also refactored and simplified significantly.

There are also some changes to llama-batch, but these are mainly to patch things up so that we can push the KV cache refactor first. There is no need to review llama-batch in deep detail - the code there will be reworked soon.

TODO

  • Adapt the recurrent cache to the new interface
  • Test optimization workflow

Next PRs

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from d23f887 to 8323e23 Compare May 24, 2025 14:06
Base automatically changed from gg/kv-cache-simplify-part2 to master May 25, 2025 13:34
@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from c1434b8 to 1eec34a Compare May 25, 2025 13:42
@ggerganov ggerganov marked this pull request as ready for review May 25, 2025 14:50
@ggerganov ggerganov requested a review from ngxson as a code owner May 25, 2025 14:50
@ggerganov
Member Author

This PR should not cause any performance changes and the numerical results should be mostly the same (with some small exceptions due to the new logic in find_slot()).

Would appreciate some testing and reports for regressions. Thanks.

@ggerganov ggerganov requested a review from slaren May 25, 2025 14:52
@ngxson
Collaborator

ngxson commented May 25, 2025

I re-ran the ppl test from #13194 (comment)

master at aa50ba4

OK:   Final estimate: PPL = 7.8002 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.6848 +/- 1.03389   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2653 +/- 0.09581   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3320 +/- 0.16048   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

This PR:

OK:   Final estimate: PPL = 7.8003 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.6620 +/- 1.03339   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2642 +/- 0.09577   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3302 +/- 0.16037   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

Some results changed very slightly, so I'm not sure if this is expected.

@ggerganov
Member Author

Yes, I think this difference is expected for SWA models (note that SWA is currently disabled for Phi, so there is no difference there). It's caused by the different order in which we place the data in memory, due to the find_slot() updates. The results become identical with --swa-full - can you confirm?

@ngxson
Collaborator

ngxson commented May 25, 2025

Yes, that's right. I added --swa-full and now it becomes identical to the master version:

OK:   Final estimate: PPL = 7.8002 +/- 0.17654   ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 37.7017 +/- 1.03468   bartowski/gemma-2-9b-it-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.9658 +/- 0.11216   lmstudio-community/Phi-3.1-mini-128k-instruct-GGUF:Q4_K_M
OK:   Final estimate: PPL = 5.2654 +/- 0.09581   bartowski/CohereForAI_c4ai-command-a-03-2025-GGUF:IQ1_M
OK:   Final estimate: PPL = 7.3320 +/- 0.16048   unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S

Edit: except for gemma-2-9b-it-GGUF

@rhvall

This comment was marked as resolved.

@ngxson
Collaborator

ngxson commented May 26, 2025

I re-ran the test and the ppl stays the same as in my last comment.

Btw, just thinking, is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

@ggerganov
Member Author

I re-ran the test and the ppl stays the same as in my last comment.

The bartowski/gemma-2-9b-it-GGUF:Q4_K_M model produces the same PPL on master and on this PR with this command:

./bin/llama-perplexity -hf bartowski/gemma-2-9b-it-GGUF:Q4_K_M -f ./wikitext-2-raw/wiki.test.raw -c 16384 -fa --chunks 2 --swa-full

Maybe your reference value on master is outdated?

Btw, just thinking, is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

Can you clarify?

@ngxson
Collaborator

ngxson commented May 26, 2025

I can't run the ppl right now, but if you get the correct result, then yes, I think it could be a problem on my side.

Btw, just thinking, is it possible (and would it be useful) to add a ppl test mode that uses the KV remove API?

Can you clarify?

Currently, AFAIU the ppl test simply evaluates the text chunk by chunk, only going forward. For example, if I have 3 chunks 1-2-3, then they will be evaluated in the order 1-2-3.

But what we also want to test is, for example:

  • Evaluate chunk 1, 2
  • Remove chunk 2 from memory
  • Evaluate chunk 2, 3

So I expect the ppl to be the same as just doing 1-2-3
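
Something along these lines (just a sketch - eval_chunk() is a made-up helper, n_chunk_tokens is the chunk size in tokens, and I'm assuming the llama_kv_self_seq_rm() API for the removal):

// evaluate chunks 1 and 2
eval_chunk(ctx, chunk[0], /*first pos*/ 0*n_chunk_tokens);
eval_chunk(ctx, chunk[1], /*first pos*/ 1*n_chunk_tokens);

// remove chunk 2 from the KV cache
llama_kv_self_seq_rm(ctx, /*seq_id*/ 0, /*p0*/ 1*n_chunk_tokens, /*p1*/ 2*n_chunk_tokens);

// re-evaluate chunk 2, then continue with chunk 3
eval_chunk(ctx, chunk[1], /*first pos*/ 1*n_chunk_tokens);
eval_chunk(ctx, chunk[2], /*first pos*/ 2*n_chunk_tokens);

// the final ppl should match evaluating 1-2-3 straight through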

@slaren
Member

slaren commented May 26, 2025

How does this recover from a failed call to graph_compute? What is the replacement for commit/restore?

@ggerganov
Member Author

How does this recover from a failed call to graph_compute? What is the replacement for commit/restore?

There are some tricky scenarios in which we could have overwritten some of the data in the cache by the time the error occurs (i.e. we have processed the first few ubatches, but not all of them yet). Before (i.e. on master), we allowed placing ubatches only in empty slots, so in such cases we could simply mark the cells back as empty and recover. But with the new logic, this is no longer guaranteed because we allow placing ubatches in masked slots. The new logic is quite beneficial because it will enable smaller caches for SWA (i.e. n_swa + n_ubatch vs n_swa + n_batch), and we also no longer have to explicitly prune SWA-masked tokens after a successful batch, which allows us to seamlessly do short rollbacks. The latter is needed for speculative decoding (#13747) and for cases where the last generated chat response contains a few extra newlines, which are then discarded by the Web UI. In that case, if we pruned all tokens strictly by the SWA window (as is currently done on master), this would cause full reprocessing of the context, while with the new logic we can still roll back and have all the necessary cache data available to reuse.

I think that on a compute error, the KV cache should be assumed to be in an undefined state and the application should take the necessary steps to recover (i.e. by clearing it and reprocessing the context that is currently needed). Later on, this reprocessing will become seamless, once we start storing the necessary tokens/embeddings information and add the logic for auto-reprocessing whatever is currently missing from the cache.
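
For example, something along these lines on the application side (just a sketch - assumes the llama_kv_self_clear() API, and reprocess_context() stands for application-specific logic that re-decodes whatever tokens the app still needs):

if (llama_decode(ctx, batch) != 0) {
    // the cache contents are undefined at this point - drop them
    llama_kv_self_clear(ctx);

    // re-decode the context that is currently needed (application-specific)
    reprocess_context(ctx);
}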

@slaren
Member

slaren commented May 26, 2025

I am mostly concerned about the abort callback functionality. Errors in the backend are likely to be unrecoverable, but I am not sure if the abort functionality makes sense if it leaves the cache in a bad state.

@ggerganov
Member Author

I admit that I had completely forgotten about the abort callback. Let me see if we can do something about this.

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from 0b73da5 to 2252eef Compare May 27, 2025 13:11
@ggerganov ggerganov marked this pull request as draft May 27, 2025 13:32
@ggerganov
Member Author

Drafting for now as I want to do some more testing and think about the abort mechanism.

@gabe-l-hart gabe-l-hart mentioned this pull request May 27, 2025
Comment on lines +594 to +533
const llama_seq_id seq_id = ubatch.seq_id[i][0];

// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
//   - mask causally, if the sequence is the same as the one we are inserting
//   - mask SWA, using current max pos for that sequence in the cache
//     always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i);

if (!can_use && cells.seq_count(head_cur + i) == 1) {
Contributor

that would automatically disqualify all of the other logic around reusing full cells?

Assuming this is correct, I think this would be the correct approach?

Suggested change

-const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
 // can we use this cell? either:
 // - the cell is empty
 // - the cell is occupied only by one sequence:
 //   - mask causally, if the sequence is the same as the one we are inserting
 //   - mask SWA, using current max pos for that sequence in the cache
 //     always insert in the cell with minimum pos
 bool can_use = cells.is_empty(head_cur + i);

-if (!can_use && cells.seq_count(head_cur + i) == 1) {
+if (!can_use && cells.seq_count(head_cur + i) == 1 && ubatch.n_seqs == 1) {
+    const llama_seq_id seq_id = ubatch.seq_id[0][0];

Contributor

That diff is gross, but it just adds an extra conditional to the outer check that checks whether ubatch.n_seqs == 1 and then always uses ubatch.seq_id[0][0].

Collaborator

@compilade compilade May 28, 2025

@gabe-l-hart It should not be necessary to limit this branch to when ubatch.n_seqs is 1. This almost never happens for simple splits anyway, except when n_ubatch is 1.

See #13746 (comment).

Contributor

Race condition! Thanks thanks

Member Author

I think that we should simplify llama_ubatch by enforcing that the tokens in the batch belong to only one sequence. The use-case for multiple sequences per input token is very rare and can trivially be achieved with llama_kv_self_seq_cp() if needed. Hence I added these TODOs:

int32_t      *  n_seq_id; // [n_seqs]  // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id;   // [n_seqs]  // TODO: become llama_seq_id * seq_id
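
i.e. after the simplification, the ubatch would carry something like this instead (just a sketch of the intended shape, not part of this PR):

llama_seq_id * seq_id; // [n_seqs] - each token belongs to exactly one sequence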

Contributor

Thanks! Simplifying the multi-sequence-per-token logic would certainly help from a clarity perspective (having recently tried to understand the current implementation and been only partially successful).

Collaborator

The use-case for multiple sequences per input token is very rare

If the public llama_batch API is changed (or at least common_batch_add), I think the only places where multiple sequences per input token are used are in tools/perplexity/perplexity.cpp for the hellaswag, winogrande and multiple-choice benchmarks (I don't know whether any 3rd-party project uses that feature, though).

I can say this will simplify part of the recurrent cache's find_slot (since it did attempt to handle multi-sequence tokens, at least enough to make hellaswag run properly).

Another reason to remove this is that it is not obvious what should happen when multiple sequences are used for a new token after those sequences have already diverged. The current behavior is to use the first seq_id and overwrite the states of the other sequences that are part of that token (at least for the recurrent cache). I'm not sure how that case is handled for the unified cache, but it is very hard to handle correctly (I'm not even sure what the correct behavior should be here). This case doesn't really happen in practice, since multiple sequences per input token are very rarely used, and also not in this way. But the problem is that they could be, and it leads to confusing behavior.

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 2 times, most recently from 825efad to eed741e Compare May 29, 2025 13:22
@ggerganov
Member Author

@slaren In eed741e I think I managed to extract the head and n_kv state from the KV cache object completely. It is now contained in the llama_memory_state objects.

We no longer pass the memory object when building the compute graphs. Instead, we prepare a memory state for each ubatch and we pass this state to the graph building context. The memory state carries the necessary information about the current head and n_kv.

I was also able to elegantly replace the llama_kv_cache::set_full() concept with llama_memory_state_ptr llama_kv_cache::init_full(), which makes more sense semantically. I plan to apply the same idea to replace the sched_defrag() method in a follow-up PR, in order to extract the defrag_info state from the KV cache in a similar way.

Sorry for the large diff again. Let me know if you have any follow-up comments or suggestions.

Member

@slaren slaren left a comment

This looks good to me. Some notes for the future:

  • init_batch and init_full could be moved to llama_memory_i. There are still many places where llama_kv_cache is used directly instead of llama_memory_i. llama_kv_cache_recurrent probably should not inherit from llama_kv_cache, but rather it should be a completely separate implementation of llama_memory_i. I believe that's already the plan, and llama_kv_cache is only used in this way to simplify the transition, but ultimately functions like llama_decode should not depend on llama_kv_cache.

  • llama_kv_cache could probably be renamed to llama_kv_cache_i for consistency

  • From what I can tell, llama_kv_cache_unified_state_i, llama_kv_cache_unified_iswa_state_i, llama_kv_cache_recurrent_state_i do not need to be interfaces. Is the goal to hide the implementation details from the header? Since virtual functions have a nonzero performance cost, I would be wary about turning everything into an interface and adding an indirection to every function call, even if it is not likely to have a significant performance impact at the moment.

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch 3 times, most recently from 9d05381 to 2b984f4 Compare May 30, 2025 08:29
@ggerganov ggerganov marked this pull request as ready for review May 30, 2025 08:46
@ggerganov
Member Author

ggerganov commented May 30, 2025

From what I can tell, llama_kv_cache_unified_state_i, llama_kv_cache_unified_iswa_state_i, llama_kv_cache_recurrent_state_i do not need to be interfaces.

Updated - these are no longer interfaces.

  • init_batch and init_full could be moved to llama_memory_i. There are still many places where llama_kv_cache is used directly instead of llama_memory_i. llama_kv_cache_recurrent probably should not inherit from llama_kv_cache, but rather it should be a completely separate implementation of llama_memory_i. I believe that's already the plan, and llama_kv_cache is only used in this way to simplify the transition, but ultimately functions like llama_decode should not depend on llama_kv_cache.
  • llama_kv_cache could probably be renamed to llama_kv_cache_i for consistency

Yes, the goal is to completely migrate to the llama_memory concept and hide the "KV" specifics from the APIs.


I set the PR ready for review and plan to merge it soon, unless we spot any regressions. The next short-term steps will be:

At this point, I think we should focus on refactoring the llama_batch logic and API.

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part3 branch from f23e4cc to 71619f2 Compare May 31, 2025 07:05
@ggerganov ggerganov changed the title kv-cache : simplify kv-cache : refactor + add llama_memory_state_i May 31, 2025
@ggerganov ggerganov merged commit 12d0188 into master May 31, 2025
53 of 55 checks passed
@ggerganov ggerganov deleted the gg/kv-cache-simplify-part3 branch May 31, 2025 07:24